package com.kool.kmqtt.server.repository.local;

import com.kool.kmqtt.server.packet.Packet;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;

import static com.kool.kmqtt.util.TopicUtil.getNextItmes;
import static com.kool.kmqtt.util.TopicUtil.split;

/**
 * In-memory tree of retained messages, indexed by topic level.
 *
 * <p>Each topic name is split into levels with {@code TopicUtil.split}. Every level is a node keyed
 * by its level name, and the retained packets of a topic live on the node reached by its last
 * level. {@link #match(String)} resolves topic filters that use the {@code +} (single-level) and
 * {@code #} (multi-level) wildcards.
 *
 * <p>This is a singleton; use {@link #getInstance()}. Every level map is a {@link ConcurrentHashMap}
 * and missing levels are inserted with {@link ConcurrentHashMap#putIfAbsent}, so concurrent
 * {@link #add} calls on overlapping topics cannot overwrite each other's subtrees. Nodes are never
 * removed once created; {@link #delete(String)} only empties a node's packet queue.
 */
public class RetainTrees {
    private static final RetainTrees INSTANCE = new RetainTrees();

    /** Top-level nodes, keyed by the first level of each topic name. */
    private final ConcurrentHashMap<String, Node> trees = new ConcurrentHashMap<>();

    private RetainTrees() {
    }

    /**
     * Returns the shared instance.
     *
     * @return the singleton {@code RetainTrees}
     */
    public static RetainTrees getInstance() {
        return INSTANCE;
    }

    /** One topic level in the tree. Static so nodes do not keep a reference to the outer instance. */
    private static final class Node {
        /** Child subtrees: key is the next topic-level name, value is that subtree's root node. */
        final ConcurrentHashMap<String, Node> children = new ConcurrentHashMap<>();

        /** Retained packets for the topic that ends at this node. */
        final LinkedBlockingQueue<Packet> retainPackets = new LinkedBlockingQueue<>();
    }

    /**
     * Stores a retained packet for a topic and creates any missing tree levels.
     *
     * <p>The packet is appended to the topic's queue, and packets previously retained for the same
     * topic are kept. NOTE(review): MQTT 3.1.1 section 3.3.1.3 says a new retained message replaces
     * the previous one. Confirm that callers invoke {@link #delete(String)} first if replacement is
     * intended.
     *
     * @param topicName topic name (without wildcards) that the packet was published to
     * @param packet    packet to retain
     */
    public void add(String topicName, Packet packet) {
        updateNode(trees, split(topicName), packet);
    }

    /**
     * Walks down from {@code children} along {@code topicItems}, creating missing levels, and
     * appends {@code packet} to the node at the end of the path.
     *
     * @param children   level map to start from
     * @param topicItems remaining topic levels (first element is the level to resolve here)
     * @param packet     packet to retain
     */
    private void updateNode(ConcurrentHashMap<String, Node> children, String[] topicItems, Packet packet) {
        Node node = childFor(children, topicItems[0]);
        String[] nextItems = getNextItmes(topicItems);
        if (nextItems == null) {
            // The last level is reached, so this node represents the full topic.
            node.retainPackets.offer(packet);
        } else {
            updateNode(node.children, nextItems, packet);
        }
    }

    /**
     * Returns the child node named {@code levelName}, inserting an empty one atomically if absent.
     *
     * @param children  level map to look in
     * @param levelName topic-level name
     * @return the existing or newly inserted node, never {@code null}
     */
    private Node childFor(ConcurrentHashMap<String, Node> children, String levelName) {
        Node node = children.get(levelName);
        if (node != null) {
            return node;
        }
        Node created = new Node();
        // A plain get-then-put would let a racing thread replace a freshly inserted subtree
        // and lose its packet. putIfAbsent keeps whichever node was inserted first.
        Node existing = children.putIfAbsent(levelName, created);
        return existing != null ? existing : created;
    }

    /**
     * Removes every retained packet stored for exactly {@code topicName}.
     *
     * <p>Packets of descendant topics are untouched, and the emptied node stays in the tree. If the
     * topic was never added, this does nothing.
     *
     * @param topicName topic name (without wildcards)
     */
    public void delete(String topicName) {
        deleteScan(trees, split(topicName));
    }

    /**
     * Follows {@code topicItems} down from {@code children} and clears the queue of the final node.
     * Stops silently if any level on the path is missing.
     */
    private void deleteScan(ConcurrentHashMap<String, Node> children, String[] topicItems) {
        if (topicItems == null || topicItems.length == 0) {
            return;
        }
        Node node = children.get(topicItems[0]);
        if (node == null) {
            return;
        }
        if (topicItems.length == 1) {
            node.retainPackets.clear();
        } else {
            deleteScan(node.children, getNextItmes(topicItems));
        }
    }

    /**
     * Collects all retained packets whose topic matches {@code topicFilter}.
     *
     * <p>{@code +} matches exactly one level. {@code #} matches the level it follows and every
     * descendant, so {@code a/#} also returns packets retained on {@code a}. NOTE(review): topics
     * starting with {@code $} are not excluded from leading wildcards, unlike MQTT 3.1.1 section
     * 4.7.2.
     *
     * @param topicFilter topic filter, possibly containing wildcards
     * @return a new mutable list of matching packets; empty if nothing matches
     */
    public List<Packet> match(String topicFilter) {
        return matchScan(trees, split(topicFilter));
    }

    /**
     * Matches {@code topicFilterItems} against the subtrees in {@code children}.
     *
     * @param children         level map to match against
     * @param topicFilterItems remaining filter levels (first element applies to this level)
     * @return a new list of matching packets
     */
    private List<Packet> matchScan(ConcurrentHashMap<String, Node> children, String[] topicFilterItems) {
        List<Packet> packets = new ArrayList<>();
        if (children.isEmpty() || topicFilterItems == null || topicFilterItems.length == 0) {
            return packets;
        }

        String head = topicFilterItems[0];
        boolean singleLevelWildcard = "+".equals(head);

        if (topicFilterItems.length >= 2 && "#".equals(topicFilterItems[1])) {
            // "x/#" or "+/#": the node(s) matched at this level plus all of their descendants.
            if (singleLevelWildcard) {
                for (Node node : children.values()) {
                    collectSubtree(node, packets);
                }
            } else {
                Node node = children.get(head);
                // The level may not exist; previously this dereferenced null and threw an NPE.
                if (node != null) {
                    collectSubtree(node, packets);
                }
            }
        } else if ("#".equals(head)) {
            // "#": every node at this level and all of their descendants.
            for (Node node : children.values()) {
                collectSubtree(node, packets);
            }
        } else if (singleLevelWildcard) {
            for (Node node : children.values()) {
                matchLevel(node, topicFilterItems, packets);
            }
        } else {
            // A literal level needs a direct lookup, not a scan over all keys.
            Node node = children.get(head);
            if (node != null) {
                matchLevel(node, topicFilterItems, packets);
            }
        }
        return packets;
    }

    /**
     * Handles a node whose level matched {@code topicFilterItems[0]}. On the last filter level it
     * takes the node's packets; otherwise it continues matching below the node.
     */
    private void matchLevel(Node node, String[] topicFilterItems, List<Packet> out) {
        if (topicFilterItems.length == 1) {
            out.addAll(node.retainPackets);
        } else {
            out.addAll(matchScan(node.children, getNextItmes(topicFilterItems)));
        }
    }

    /** Appends the packets of {@code node} and of every descendant (pre-order) to {@code out}. */
    private void collectSubtree(Node node, List<Packet> out) {
        out.addAll(node.retainPackets);
        for (Node child : node.children.values()) {
            collectSubtree(child, out);
        }
    }
}
